from sentence_transformers import SentenceTransformer

import torch
from load_list_data import *
import sys


dataset_folder = "../processed_data/"
datasets = ['cora']
# Which text fields are embedded; must be one of the keys of _TARGET_FIELDS.
target = 'gpt_only_embedding'  # 'content_embedding' / 'gpt_response_embedding' / 'gpt_only_embedding'

# For each embedding target, the ordered text fields that are joined with a
# newline into one sentence per item before encoding.
_TARGET_FIELDS = {
    'content_embedding': ('content', 'title'),
    'gpt_response_embedding': ('content', 'gpt_response', 'title'),
    'gpt_only_embedding': ('gpt_response', 'title'),
}


def _build_sentences(target, fields):
    """Join the per-item text fields selected by *target* into sentences.

    Args:
        target: key of ``_TARGET_FIELDS`` choosing which fields are used and
            in which order.
        fields: mapping from field name ('title', 'content', 'gpt_response')
            to a list of strings, one entry per item.

    Returns:
        A list with one newline-joined string per title.

    Raises:
        ValueError: if *target* is unknown, or if a field used by *target*
            does not have exactly one entry per title (the texts would
            otherwise be misaligned or silently truncated).
    """
    if target not in _TARGET_FIELDS:
        raise ValueError(
            f"unknown target {target!r}; expected one of {sorted(_TARGET_FIELDS)}"
        )
    names = _TARGET_FIELDS[target]
    n_items = len(fields['title'])
    for name in names:
        if len(fields[name]) != n_items:
            raise ValueError(
                f"{name} list has {len(fields[name])} entries "
                f"but there are {n_items} titles"
            )
    return ['\n'.join(parts) for parts in zip(*(fields[name] for name in names))]


# Fail fast, before loading any data or model: with an unknown target the
# original if/elif chain left `sentences` undefined (or stale from the
# previous dataset).
if target not in _TARGET_FIELDS:
    raise ValueError(
        f"unknown target {target!r}; expected one of {sorted(_TARGET_FIELDS)}"
    )

# The encoder does not depend on the dataset, so load it once.
# Alternative encoder tried previously: 'sentence-transformers/sentence-t5-base'.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

for dataset in datasets:
    title_list = load_title_list(dataset_folder, dataset)
    content_list = load_content_list(dataset_folder, dataset)
    gpt_response_list = load_gpt_response(f'../raw_data/{dataset}', dataset)
    print(len(title_list), len(content_list), len(gpt_response_list))

    sentences = _build_sentences(target, {
        'title': title_list,
        'content': content_list,
        'gpt_response': gpt_response_list,
    })

    # NOTE(review): encode() presumably returns a NumPy array here (its
    # default output); torch.save pickles whatever object it is given.
    embeddings = model.encode(sentences)
    torch.save(embeddings, f"{dataset_folder}{dataset}/{dataset}_{target}_list.pt")